Spatial Gaussian Process inference in PyMC3

This is the first step in modelling Species occurrence. The good news is that MCMC works, The bad one is that it's computationally intense.


In [10]:
# Load Biospytial modules and etc.
%matplotlib inline
import sys
sys.path.append('/apps/external_plugins/spystats/')
#import django
#django.setup()
import pandas as pd
import matplotlib.pyplot as plt
## Use the ggplot style
plt.style.use('ggplot')
import numpy as np

In [11]:
## Model Specification
import pymc3 as pm
from spystats import tools

Simulated gaussian data


In [12]:
sigma=3.5
range_a=10.13
kappa=3.0/2.0
#ls = 0.2
#tau = 2.0
cov = sigma * pm.gp.cov.Matern32(2, range_a,active_dims=[0,1])

In [13]:
n = 10
grid = tools.createGrid(grid_sizex=n,grid_sizey=n,minx=0,miny=0,maxx=50,maxy=50)

In [14]:
K = cov(grid[['Lon','Lat']].values).eval()
sample = pm.MvNormal.dist(mu=np.zeros(K.shape[0]), cov=K).random(size=1)
grid['Z'] = sample

In [15]:
plt.figure(figsize=(14,4))
plt.imshow(grid.Z.values.reshape(n,n),interpolation=None)


Out[15]:
<matplotlib.image.AxesImage at 0x7fcd948daf50>

In [16]:
print("sigma: %s, phi: %s"%(sigma,range_a))


sigma: 3.5, phi: 10.13

In [18]:
## Analysis, GP only one parameter to fit
# The variational method is much beter.
with pm.Model() as model:
    
    #sigma = 1.0
    sigma = pm.Uniform('sigma',0,4)
    phi = pm.Normal('phi',mu=8,sd=3)
#    phi = pm.Uniform('phi',5,10)

    cov = sigma * pm.gp.cov.Matern32(2,phi,active_dims=[0,1])
    K = cov(grid[['Lon','Lat']].values)
    y_obs = pm.MvNormal('y_obs',mu=np.zeros(n*n),cov=K,observed=grid.Z)
    
    #gp = pm.gp.Latent(cov_func=cov,observed=sample)
    # Use elliptical slice sampling
    #ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K)
    #ess_Step = pm.HamiltonianMC()
    #%time trace = pm.sample(5000)
    ## Variational
    %time results = pm.fit()


Average Loss = 139.51: 100%|██████████| 10000/10000 [00:28<00:00, 348.57it/s]
CPU times: user 1min 36s, sys: 2.14 s, total: 1min 39s
Wall time: 30.6 s

Diagnostics

For one parameter it took around 1.3 minutes For two parameters it took 4min 27 seconds


In [19]:
from pymc3 import find_MAP
map_estimate = find_MAP(model=model)


logp = -139.58, ||grad|| = 2.6669e-05: 100%|██████████| 12/12 [00:00<00:00, 341.72it/s]  

Simulated Poisson data with latent Gaussian Field


In [20]:
np.random.seed(1234)

sigma=3.5
range_a=10.13
kappa=3.0/2.0
#ls = 0.2
alpha = 0.0
cov = sigma * pm.gp.cov.Matern32(2, range_a,active_dims=[0,1])
n = 20
grid = tools.createGrid(grid_sizex=n,grid_sizey=n,minx=0,miny=0,maxx=20,maxy=20)
K = cov(grid[['Lon','Lat']].values).eval()
pfield = pm.MvNormal.dist(mu=np.zeros(K.shape[0]), cov=K).random(size=1)

poiss_data = np.exp(alpha + pfield)

grid['Z'] = poiss_data
#grid['Z'] = pfield
plt.figure(figsize=(14,4))
plt.imshow(grid.Z.values.reshape(n,n),interpolation=None)
plt.colorbar()
print("sigma: %s, phi: %s"%(sigma,range_a))


sigma: 3.5, phi: 10.13

In [21]:
## Analysis, GP only one parameter to fit
# The variational method is much beter.
from pymc3.variational.callbacks import CheckParametersConvergence

with pm.Model() as model:
    sigma=3.5
    range_a=10.13
    
    
    #sigma = pm.Uniform('sigma',0,4)
    #phi = pm.HalfNormal('phi',mu=8,sd=3)
    #phi = pm.Uniform('phi',6,12)
    phi = pm.Uniform('phi',5,15)
    cov = sigma * pm.gp.cov.Matern32(2,phi,active_dims=[0,1])
    #K = cov(grid[['Lon','Lat']].values)
    #phiprint = tt.printing.Print('phi')(phi)
    
    ## The latent function
    gp = pm.gp.Latent(cov_func=cov)
    
    ## I don't know why this
    f = gp.prior("latent_field", X=grid[['Lon','Lat']].values,reparameterize=True)
    
    #f_print = tt.printing.Print('latent_field')(f)
    
    y_obs = pm.Poisson('y_obs',mu=f,observed=grid.Z)
    
    #y_obs = pm.MvNormal('y_obs',mu=np.zeros(n*n),cov=K,observed=grid.Z)
    
    #gp = pm.gp.Latent(cov_func=cov,observed=sample)
    # Use elliptical slice sampling
    #ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K)
    #step = pm.HamiltonianMC()
    #step = pm.Metropolis()
    #%time trace = pm.sample(5000,step)#,tune=0,chains=1)
    ## Variational
    
    %time mean_field = pm.fit(method='advi', callbacks=[CheckParametersConvergence()])


Average Loss = inf: 100%|██████████| 10000/10000 [03:30<00:00, 47.47it/s]
CPU times: user 12min 55s, sys: 12.9 s, total: 13min 8s
Wall time: 3min 32s

ESsta dando un monton de inf en averafe lost


In [22]:
# pm.traceplot(trace)

In [23]:
#for RV in model.basic_RVs:
#    print(RV.name, RV.logp(model.test_point))

In [24]:
from pymc3 import find_MAP
map_estimate = find_MAP(model=model)
map_estimate


logp = -inf, ||grad|| = 214.35: 100%|██████████| 3/3 [00:00<00:00, 52.60it/s]
Out[24]:
{'latent_field': array([-0.3834217 , -0.41007374, -0.43583168, -0.4601648 , -0.4825283 ,
        -0.50238856, -0.51925151, -0.53269178, -0.54237913, -0.54809885,
        -0.54976291, -0.54741049, -0.54119765, -0.53137821, -0.51827874,
        -0.502271  , -0.48374474, -0.46308261, -0.44063796, -0.41671505,
        -0.42857215, -0.45936469, -0.48921593, -0.51749078, -0.54353027,
        -0.56668294, -0.58634113, -0.6019791 , -0.61318854, -0.61970607,
        -0.6214286 , -0.61841367, -0.61086516, -0.59910711, -0.58355045,
        -0.56465742, -0.54290813, -0.51877131, -0.49268106, -0.46502698,
        -0.4771009 , -0.51254917, -0.54702953, -0.57978574, -0.61002385,
        -0.63695061, -0.65981963, -0.67798131, -0.69093038, -0.69834355,
        -0.70010052, -0.69628413, -0.68716007, -0.67314079, -0.65474072,
        -0.63253056, -0.60709641, -0.57900781, -0.54879816, -0.51696759,
        -0.52875988, -0.56939683, -0.60906844, -0.64687973, -0.68187894,
        -0.71310384, -0.73963986, -0.76068513, -0.77561442, -0.78403094,
        -0.78579536, -0.78102558, -0.77006775, -0.753446  , -0.73180203,
        -0.70583633, -0.67625935, -0.64375753, -0.60897749, -0.57253092,
        -0.58313849, -0.62949584, -0.67492884, -0.71838623, -0.75873437,
        -0.7948123 , -0.82550359, -0.84982063, -0.8669905 , -0.87652665,
        -0.87826937, -0.87238496, -0.85932416, -0.83975158, -0.81446405,
        -0.7843152 , -0.75015737, -0.71280566, -0.67302445, -0.63153252,
        -0.63963485, -0.69221609, -0.74396115, -0.79364583, -0.83993079,
        -0.8814253 , -0.91677457, -0.94476712, -0.9644499 , -0.97522767,
        -0.97691963, -0.96975766, -0.95432684, -0.93146613, -0.90215791,
        -0.86743175, -0.82829585, -0.78569948, -0.74052295, -0.6935867 ,
        -0.69743107, -0.75667547, -0.81522527, -0.8716685 , -0.92443791,
        -0.97188341, -1.01237531, -1.04443797, -1.06690267, -1.07904518,
        -1.08066365, -1.07207541, -1.05403386, -1.02758897, -0.99393856,
        -0.95430764, -0.90986892, -0.86170482, -0.81080429, -0.75808186,
        -0.75547714, -0.82171494, -0.88745426, -0.95108237, -1.01078467,
        -1.06462644, -1.11067218, -1.14714583, -1.17262985, -1.18625499,
        -1.18779906, -1.17768205, -1.15685855, -1.12662312, -1.08841327,
        -1.04366293, -0.99371195, -0.93976864, -0.8829174 , -0.82415544,
        -0.81248903, -0.88588981, -0.95903608, -1.0301022 , -1.09700844,
        -1.15751575, -1.20936452, -1.25045977, -1.27911844, -1.2943307 ,
        -1.29585551, -1.28421356, -1.26057872, -1.22650731, -1.18368892,
        -1.13377349, -1.07826602, -1.01848873, -0.95560251, -0.89067041,
        -0.86696701, -0.94748525, -1.0280252 , -1.10653457, -1.18064915,
        -1.2478083 , -1.30542234, -1.35109707, -1.38292589, -1.39980673,
        -1.40148172, -1.38852773, -1.36229871, -1.32459367, -1.27735606,
        -1.22245627, -1.16156166, -1.09609463, -1.02727033, -0.95619404,
        -0.91723968, -1.00456366, -1.09219427, -1.17783608, -1.25881674,
        -1.33223649, -1.39518823, -1.44504731, -1.47984353, -1.49844044,
        -1.50056596, -1.48679995, -1.45852013, -1.4177019 , -1.36652861,
        -1.30709644, -1.24123268, -1.17044664, -1.09598822, -1.0189791 ,
        -0.96153778, -1.05504857, -1.14913202, -1.24123244, -1.32834348,
        -1.40722566, -1.47472324, -1.52816716, -1.5656089 , -1.58592347,
        -1.58882795, -1.57486514, -1.54535455, -1.50226434, -1.4479516 ,
        -1.38472874, -1.31457455, -1.2390684 , -1.15948115, -1.07694563,
        -0.9981003 , -1.09684859, -1.19638911, -1.29389613, -1.38601768,
        -1.46922797, -1.54030971, -1.59664061, -1.63633758, -1.65831217,
        -1.66227568, -1.64871695, -1.61885662, -1.57456188, -1.51817949,
        -1.45218142, -1.37865899, -1.29924057, -1.2152088 , -1.12772698,
        -1.02531203, -1.12802344, -1.23167333, -1.33318691, -1.42889708,
        -1.51515882, -1.58878491, -1.64724434, -1.68874315, -1.71224616,
        -1.71746594, -1.70483347, -1.67545431, -1.63104488, -1.57383329,
        -1.50627834, -1.43049519, -1.34814291, -1.2605404 , -1.16892443,
        -1.04186134, -1.14698901, -1.25310715, -1.35696214, -1.45471005,
        -1.54267491, -1.61775005, -1.67752061, -1.72029704, -1.74510821,
        -1.75167517, -1.74037523, -1.71219823, -1.66869185, -1.61188312,
        -1.54404709, -1.46720822, -1.383026  , -1.29290178, -1.19825954,
        -1.04688492, -1.15271213, -1.25949614, -1.36390892, -1.46206281,
        -1.55032394, -1.62569797, -1.68590131, -1.72935556, -1.75515707,
        -1.76303848, -1.75332723, -1.72690052, -1.68513041, -1.62979463,
        -1.56286715, -1.48620588, -1.40144219, -1.31008571, -1.21381421,
        -1.04006119, -1.14481706, -1.2504259 , -1.35358957, -1.45048928,
        -1.53760713, -1.61209312, -1.67180393, -1.71527075, -1.74165279,
        -1.75069002, -1.74265706, -1.71831461, -1.67884841, -1.62577228,
        -1.56076598, -1.48549435, -1.40154091, -1.31050949, -1.21427032,
        -1.02164367, -1.12360942, -1.22626008, -1.32642291, -1.42044931,
        -1.5050096 , -1.57743066, -1.63571442, -1.6784939 , -1.70497992,
        -1.71490973, -1.70849759, -1.6863829 , -1.6495664 , -1.59932215,
        -1.53708351, -1.46434069, -1.38261679, -1.2935607 , -1.1991372 ,
        -0.99244655, -1.09005747, -1.18812065, -1.28367348, -1.37333427,
        -1.45402937, -1.52328902, -1.57926012, -1.6206658 , -1.64675464,
        -1.65724983, -1.65229927, -1.63242332, -1.59845467, -1.5514667 ,
        -1.49269697, -1.42349177, -1.34530734, -1.25978784, -1.16890262,
        -0.95378601, -1.04573643, -1.13784652, -1.22743088, -1.31146848,
        -1.38719639, -1.45236142, -1.50525023, -1.54466449, -1.569879  ,
        -1.58059522, -1.57689391, -1.55918607, -1.5281605 , -1.48472872,
        -1.42997515, -1.36512895, -1.29157618, -1.21091932, -1.12506583]),
 'latent_field_rotated_': array([-0.20494749, -0.10095638, -0.07905667, -0.07580585, -0.07476589,
        -0.07234559, -0.06930893, -0.06557659, -0.06135392, -0.05679548,
        -0.05207112, -0.04732661, -0.0426775 , -0.03820181, -0.03394069,
        -0.02990269, -0.02606947, -0.02240043, -0.01883474, -0.01529026,
        -0.1524021 , -0.06731612, -0.07468189, -0.07543851, -0.0777946 ,
        -0.07855393, -0.07839125, -0.0770924 , -0.07478364, -0.07157796,
        -0.06765783, -0.06322154, -0.05846621, -0.05356574, -0.04865757,
        -0.04385315, -0.03920251, -0.03488334, -0.03073877, -0.0289584 ,
        -0.08196834, -0.04380591, -0.05198339, -0.05407075, -0.05768883,
        -0.0601566 , -0.06187716, -0.06257812, -0.06226409, -0.06097348,
        -0.05883516, -0.05602586, -0.05274346, -0.04917997, -0.04550224,
        -0.04187779, -0.03844754, -0.03567951, -0.03388969, -0.03450034,
        -0.07549382, -0.04041733, -0.04939809, -0.05203178, -0.05621666,
        -0.0592529 , -0.06148518, -0.06259721, -0.06256401, -0.06142302,
        -0.05932925, -0.05650403, -0.05319491, -0.04963718, -0.04603131,
        -0.04257153, -0.03942811, -0.03710138, -0.03584006, -0.03651759,
        -0.07741419, -0.04288697, -0.05224779, -0.05554156, -0.06033132,
        -0.06390916, -0.06653767, -0.06784932, -0.06779203, -0.06641879,
        -0.06393702, -0.06064099, -0.05684841, -0.0528433 , -0.04884744,
        -0.04505551, -0.04162486, -0.03905525, -0.03758411, -0.03810103,
        -0.07829439, -0.04485248, -0.0545332 , -0.05854348, -0.06396671,
        -0.06813076, -0.07119061, -0.07271434, -0.07261209, -0.07095851,
        -0.06803011, -0.06421622, -0.0599229 , -0.05548677, -0.05114331,
        -0.04707624, -0.043418  , -0.04065068, -0.03898145, -0.03929501,
        -0.07792118, -0.04622849, -0.05608995, -0.06081812, -0.06685529,
        -0.07161358, -0.07512219, -0.07686753, -0.07671013, -0.07475582,
        -0.07136042, -0.06702551, -0.06226498, -0.05746403, -0.05285574,
        -0.0485991 , -0.04478743, -0.04186533, -0.03999763, -0.04006557,
        -0.07597667, -0.04671387, -0.05658838, -0.06196443, -0.06851482,
        -0.07379484, -0.0777125 , -0.07966218, -0.07944842, -0.07723986,
        -0.07344654, -0.06866878, -0.06355752, -0.05852437, -0.05378088,
        -0.04945883, -0.04560439, -0.04260374, -0.04055684, -0.04031915,
        -0.07237884, -0.0461642 , -0.05584601, -0.0617139 , -0.06855936,
        -0.07413825, -0.07828469, -0.08034239, -0.08006642, -0.07780343,
        -0.0738728 , -0.06885952, -0.06362342, -0.05853456, -0.05380256,
        -0.04955165, -0.04577845, -0.04278799, -0.04058279, -0.03995149,
        -0.06715175, -0.0445    , -0.05377308, -0.05990921, -0.06671818,
        -0.07217768, -0.07608877, -0.07795095, -0.07765217, -0.07577966,
        -0.0723036 , -0.0675128 , -0.06242755, -0.05742752, -0.05283808,
        -0.0487815 , -0.04520722, -0.04231654, -0.03998046, -0.03883516,
        -0.0604601 , -0.04173981, -0.05041982, -0.05659094, -0.06300666,
        -0.06785047, -0.0709469 , -0.07215064, -0.07215994, -0.07098108,
        -0.06854363, -0.0647727 , -0.05996689, -0.05519208, -0.05082483,
        -0.04706693, -0.04377374, -0.0410345 , -0.03859492, -0.03681317,
        -0.052581  , -0.03795681, -0.04591476, -0.05191578, -0.05763035,
        -0.06152919, -0.06370612, -0.06484938, -0.0650707 , -0.06438871,
        -0.062751  , -0.0600587 , -0.05626533, -0.05177071, -0.04773291,
        -0.04435677, -0.04137886, -0.03873919, -0.03612414, -0.03365744,
        -0.04387162, -0.03322592, -0.04035997, -0.04597342, -0.0507089 ,
        -0.05356581, -0.05544946, -0.05650033, -0.05683618, -0.05649136,
        -0.05543684, -0.05358695, -0.0508481 , -0.04728798, -0.04357812,
        -0.04073138, -0.03801762, -0.0353799 , -0.03239432, -0.02911553,
        -0.03483696, -0.02773845, -0.03392948, -0.03899014, -0.04261998,
        -0.04501021, -0.04657076, -0.04746728, -0.04782586, -0.04769797,
        -0.04707682, -0.04589667, -0.04404916, -0.04148532, -0.0386194 ,
        -0.03624781, -0.0336891 , -0.03100492, -0.02777698, -0.02379231,
        -0.02609093, -0.02191012, -0.02701169, -0.03140952, -0.03433961,
        -0.03620604, -0.03739294, -0.03807637, -0.03837947, -0.03836164,
        -0.03803702, -0.03737625, -0.03630617, -0.03473215, -0.03275415,
        -0.03079089, -0.02832872, -0.02550893, -0.0221603 , -0.01826185,
        -0.01817713, -0.01625148, -0.02029448, -0.02393898, -0.02613643,
        -0.02745715, -0.02825734, -0.02870348, -0.0289004 , -0.0288997 ,
        -0.0287207 , -0.02835784, -0.027784  , -0.0269515 , -0.02578982,
        -0.02421586, -0.02212435, -0.01953038, -0.01637254, -0.01297872,
        -0.01143394, -0.01101474, -0.01404649, -0.01687404, -0.01834726,
        -0.01915075, -0.0196022 , -0.01983593, -0.01992574, -0.01990414,
        -0.01978117, -0.01955033, -0.01919038, -0.01866466, -0.01791704,
        -0.01687472, -0.01543594, -0.01355939, -0.01113939, -0.00838121,
        -0.00616378, -0.00649324, -0.00857165, -0.01055798, -0.01136989,
        -0.01174511, -0.01193571, -0.0120232 , -0.01204682, -0.01202089,
        -0.01194826, -0.01182317, -0.0116318 , -0.01135068, -0.01094265,
        -0.01035493, -0.00950578, -0.00832748, -0.00669003, -0.00467902,
        -0.00256191, -0.00300083, -0.00419494, -0.00536893, -0.00565432,
        -0.00574979, -0.00579273, -0.00580659, -0.00580326, -0.00578525,
        -0.00575197, -0.00569985, -0.00562215, -0.00550771, -0.00533822,
        -0.00508544, -0.00470068, -0.0041258 , -0.0032451 , -0.00204349,
        -0.00062085, -0.00081911, -0.0012532 , -0.00171723, -0.00171474,
        -0.00171277, -0.00171016, -0.0017067 , -0.001702  , -0.00169554,
        -0.00168653, -0.0016738 , -0.00165547, -0.00162856, -0.00158804,
        -0.00152557, -0.00142517, -0.00126255, -0.00098199, -0.00053731]),
 'phi': array(10.0),
 'phi_interval__': array(0.0)}

In [87]:
plt.imshow(map_estimate['latent_field'].reshape(20,20))


Out[87]:
<matplotlib.image.AxesImage at 0x7fa2c84d8a50>

In [26]:
pm.plot_posterior(mean_field.sample(10), color='LightSeaGreen');


Examine actual posterior distribution

The posterior is analytically tractable so we can compute the posterior mean explicitly. Rather than computing the inverse of the covariance matrix K, we use the numerically stable calculation described Algorithm 2.1 in the book “Gaussian Processes for Machine Learning” (2006) by Rasmussen and Williams, which is available online for free.


In [ ]:
fig, ax = plt.subplots(figsize=(14, 6));
ax.scatter(X0, f, s=40, color='b', label='True points');

# Analytically compute posterior mean
## This is the cholesky decomposition of the Covariance Matrix with kernel nugget
L = np.linalg.cholesky(K_noise.eval())
## Faith step, This solves the base x's such that Lx = f and the uses x for solving y's such that L.T y = x
alpha = np.linalg.solve(L.T, np.linalg.solve(L, f))
## Multiply the posterior (ALgorithm 2.1 in Rasmunssen)
## Using the "extended matrix" K_s
post_mean = np.dot(K_s.T.eval(), alpha)

ax.plot(X0, post_mean, color='g', alpha=0.8, label='Posterior mean');

ax.set_xlim(0, 3);
ax.set_ylim(-2, 2);
ax.legend();

Ok, it's good to have the analitical solution but not always possible sooooo. Let's do some computing.

Model in PyM3


In [ ]:
with pm.Model() as model:
    # The actual distribution of f_sample doesn't matter as long as the shape is right since it's only used
    # as a dummy variable for slice sampling with the given prior
    ### From doc:
    ### 
    
    
    
    f_sample = pm.Flat('f_sample', shape=(n, ))

    ## Actually, pm.Flat is a zero array of shape n
    # Likelihood
    ## The covariance is only in the diagonal
    
    y = pm.MvNormal('y', observed=sample, mu=f_sample, cov=noise * tt.eye(n), shape=n)

    # Interpolate function values using noisy covariance matrix
    ## Deterministic allows to compose (do algebra) with RV in many different ways. 
    ##While these transformations work seamlessly, its results are not stored automatically. 
    ##Thus, if you want to keep track of a transformed variable, you have to use pm.Determinstic:
    ## from http://docs.pymc.io/notebooks/api_quickstart.html
    
    ## So in this case is transforming the rv into:
    ## the low triangular cholesky decomposition of the Covariance with nugget
    L = tt.slinalg.cholesky(K_noise)
    ## So this is for calculating the "kernel" part of the MVN i.e. (mu -x).T * (LL.T)^-1 * (mu-x)
    ## but considering mu = 0 we have that x = linalg.solve(L,y) (because Lx = y)
    ## Then, L.T*x)
    f_pred = pm.Deterministic('f_pred', tt.dot(K_s.T, tt.slinalg.solve(L.T, tt.slinalg.solve(L, f_sample))))

    # Use elliptical slice sampling
    ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K_stable)
    trace = pm.sample(5000, start=model.test_point, step=[ess_step], progressbar=False, random_seed=1)

Evaluate posterior fit

The posterior samples are consistent with the analytically derived posterior and behaves how one would expect–narrower near areas with lots of observations and wider in areas with more uncertainty.


In [ ]:
fig, ax = plt.subplots(figsize=(14, 6));
for idx in np.random.randint(4000, 5000, 500):
    ax.plot(X0, trace['f_pred'][idx],  alpha=0.02, color='navy')
ax.scatter(X0, f, s=40, color='k', label='True points');
ax.plot(X0, post_mean, color='g', alpha=0.8, label='Posterior mean');
ax.legend();
ax.set_xlim(0, 3);
ax.set_ylim(-2, 2);

In [ ]:
pm.traceplot(trace)

Clasification

In Gaussian process classification, the likelihood is not normal and thus the posterior is not analytically tractable. The prior is again a multivariate normal with covariance matrix K, and the likelihood is the standard likelihood for logistic regression: \begin{equation} L(y | f) = \Pi_n \sigma(y_n, f_n) \end{equation}

Generate some example data

We generate random samples from a Gaussian process, assign any points greater than zero to a “positive” class, and assign all other points to a “negative” class.


In [ ]:
np.random.seed(5)
f = np.random.multivariate_normal(mean=np.zeros(n), cov=K_stable.eval())

# Separate data into positive and negative classes
f[f > 0] = 1
f[f <= 0] = 0

fig, ax = plt.subplots(figsize=(14, 6));
for idx in np.random.randint(4000, 5000, 500):
    ax.plot(X, trace['f_pred'][idx],  alpha=0.02, color='navy')
ax.scatter(X0, f, s=40, color='k', label='True points');
ax.plot(X, post_mean, color='g', alpha=0.8, label='Posterior mean');
ax.legend();
ax.set_xlim(0, 3);
ax.set_ylim(-2, 2);

Sample from posterior distribution


In [ ]:
with pm.Model() as model:
    # Again, f_sample is just a dummy variable
    f_sample = pm.Flat('f_sample', shape=n)
    f_transform = pm.invlogit(f_sample)

    # Binomial likelihood
    y = pm.Binomial('y', observed=f, n=np.ones(n), p=f_transform, shape=n)

    # Interpolate function values using noiseless covariance matrix
    L = tt.slinalg.cholesky(K_stable)
    f_pred = pm.Deterministic('f_pred', tt.dot(K_s.T, tt.slinalg.solve(L.T, tt.slinalg.solve(L, f_transform))))

    # Use elliptical slice sampling
    ess_step = pm.EllipticalSlice(vars=[f_sample], prior_cov=K_stable)
    trace = pm.sample(5000, start=model.test_point, step=[ess_step], progressbar=False, random_seed=1)

Evaluate posterior fit

The posterior looks good, though the fit is, unsurprisingly, erratic outside the range of the observed data.


In [ ]:
fig, ax = plt.subplots(figsize=(14, 6));
for idx in np.random.randint(4000, 5000, 500):
    ax.plot(X, trace['f_pred'][idx],  alpha=0.04, color='navy')
ax.scatter(X0, f, s=40, color='k');
ax.set_xlim(0, 3);
ax.set_ylim(-0.1, 1.1);

In [ ]:


In [ ]: